{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "248097b8",
   "metadata": {},
   "source": [
    "# EasyEdit Example with **WISE**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81b28fda",
   "metadata": {},
   "source": [
    "> Tutorial authors: Jizhan Fang (<fangjizhan@zju.edu.cn>), Ziyan Jiang (<ziy.jiang@outlook.com>)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "753b8801",
   "metadata": {},
   "source": [
    "In this tutorial, we use `WISE` to edit the `llama2-7b-chat` model. We hope it helps you understand how to apply the WISE method to LLMs."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b0a1701",
   "metadata": {},
   "source": [
    "## Model Editing\n",
    "\n",
    "Deployed models may still make unpredictable errors. For example, Large Language Models (LLMs) notoriously hallucinate, perpetuate bias, and factually decay, so we should be able to adjust specific behaviors of pre-trained models.\n",
    "\n",
    "**Model editing** aims to efficiently adjust the behavior of an initial base model $(f_\\theta)$ on a particular edit descriptor $[x_e, y_e]$, without influencing the model's behavior on unrelated samples. An example edit descriptor:\n",
    "- $x_e$: \"Who is the president of the US?\"\n",
    "- $y_e$: \"Joe Biden.\"\n",
    "\n",
    "The ultimate goal is to create an edited model $(f_{\\theta'})$."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75b3b30a",
   "metadata": {},
   "source": [
    "## Method: WISE\n",
    "\n",
    "Paper: [WISE: Rethinking the Knowledge Memory for Lifelong Model Editing of Large Language Models](http://arxiv.org/pdf/2405.14768)\n",
    "\n",
    "**WISE** is an approach for lifelong model editing of Large Language Models (LLMs). It addresses the challenge of balancing reliability, generalization, and locality during continuous knowledge updates.\n",
    "To do so, it keeps the pretrained knowledge in the model's main memory, writes edits into a separate side memory, and uses a routing mechanism to decide which memory should answer a given query. As edits accumulate, the side memories are either merged (WISE-Merge) or retrieved (WISE-Retrieve)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9717af3a",
   "metadata": {},
   "source": [
    "## 📂 Data Preparation\n",
    "\n",
    "The datasets used in this tutorial (ZsRE) can be downloaded from [Google Drive](https://drive.google.com/file/d/1YtQvv4WvTa4rJyDYQR2J-uK8rnrt0kTA/view?usp=sharing).\n",
    "\n",
    "Each dataset contains both an **edit set** and a train set."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2cf48075",
   "metadata": {},
   "source": [
    "## Prepare the runtime environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6356ed23",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/mnt/8t/fangjizhan/EasyEdit\n",
      "data\t    examples  multimodal_edit.py   run_wise_editing.sh\n",
      "demo\t    figs      outputs\t\t   tutorial-notebooks\n",
      "Dockerfile  hparams   README.md\t\t   tutorial.pdf\n",
      "easyeditor  LICENSE   requirements.txt\n",
      "edit.py     logs      run_wise_editing.py\n"
     ]
    }
   ],
   "source": [
    "## Clone Repo\n",
    "#!git clone https://github.com/zjunlp/EasyEdit\n",
    "%cd EasyEdit\n",
    "!ls"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a104cd71",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "E: Could not open lock file /var/lib/dpkg/lock-frontend - open (13: Permission denied)\n",
      "E: Unable to acquire the dpkg frontend lock (/var/lib/dpkg/lock-frontend), are you root?\n",
      "[sudo] password for fjz: \n",
      "[sudo] password for fjz: "
     ]
    }
   ],
   "source": [
    "# System packages need root. If `sudo` asks for a password, run these three commands in a terminal instead.\n",
    "!sudo apt-get install -y python3.9 python3-pip\n",
    "!sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1\n",
    "!sudo update-alternatives --config python3\n",
    "%pip install -r requirements.txt"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b039c94a",
   "metadata": {},
   "source": [
    "## Configure Method Parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d553b513",
   "metadata": {},
   "source": [
    "```yaml\n",
    "alg_name: \"WISE\"\n",
    "model_name: \"./hugging_cache/llama2-7b-chat\"\n",
    "device: 0\n",
    "\n",
    "mask_ratio: 0.2\n",
    "edit_lr: 1.0\n",
    "n_iter: 70\n",
    "norm_constraint: 1.0\n",
    "act_margin: [5.0, 20.0, 10.0] # alpha, beta, gamma\n",
    "act_ratio: 0.88\n",
    "save_freq: 500\n",
    "merge_freq: 1000\n",
    "merge_alg: 'ties'\n",
    "objective_optimization: 'only_label'\n",
    "inner_params:\n",
    "- model.layers[27].mlp.down_proj.weight\n",
    "\n",
    "\n",
    "## alternative: WISE-Merge, WISE-Retrieve\n",
    "\n",
    "# for merge (if merge)\n",
    "densities: 0.53\n",
    "weights: 1.0\n",
    "\n",
    "# for retrieve (if retrieve, please set to True)\n",
    "retrieve: True\n",
    "replay: False # True --> will replay the past editing instances: see https://arxiv.org/abs/2405.14768 Appendix B.3\n",
    "\n",
    "model_parallel: False\n",
    "\n",
    "# for save and load\n",
    "# save_path: \"./wise_checkpoint/wise.pt\"\n",
    "# load_path: \"./wise_checkpoint/wise.pt\"\n",
    "\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e9aef0a",
   "metadata": {},
   "source": [
    "## Import models & Run\n",
    "\n",
    "### Edit llama2-7b-chat on ZsRE with WISE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7d8b6409",
   "metadata": {},
   "outputs": [],
   "source": [
    "from easyeditor import BaseEditor\n",
    "from easyeditor import WISEHyperParams"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a0ca0f1d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-10-04 01:00:48,693 - easyeditor.editors.editor - INFO - Instantiating model\n",
      "10/04/2024 01:00:48 - INFO - easyeditor.editors.editor -   Instantiating model\n"
     ]
    },
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.004851579666137695,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Loading checkpoint shards",
       "rate": null,
       "total": 2,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e6acc18f7cb149aab1ce9b5bd72529d6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/mnt/8t/fangjizhan/anaconda3/envs/easyedit/lib/python3.9/site-packages/torch/_utils.py:776: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
      "  return self.fget.__get__(instance, owner)()\n",
      "You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message\n",
      "2024-10-04 01:00:58,199 - easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...\n",
      "10/04/2024 01:00:58 - INFO - easyeditor.editors.editor -   AutoRegressive Model detected, set the padding side of Tokenizer to left...\n",
      "  0%|          | 0/3 [00:00<?, ?it/s]We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)\n",
      "100%|██████████| 3/3 [00:01<00:00,  2.44it/s]\n",
      "  0%|          | 0/3 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "New weights successfully inserted into model.layers[27].mlp.down_proj.weight\n",
      "Executing WISE algorithm for the update: \n",
      "[What university did Watts Humphrey attend?] -> [University of Michigan]\n",
      "loss 32.304 = 2.304 + 30.0\n",
      "loss 29.505 = 0.73 + 28.775\n",
      "loss 26.213 = 0.325 + 25.888\n",
      "loss 22.984 = 0.193 + 22.791\n",
      "loss 18.047 = 0.155 + 17.893\n",
      "loss 11.574 = 0.141 + 11.434\n",
      "loss 8.464 = 0.126 + 8.338\n",
      "loss 2.708 = 0.108 + 2.6\n",
      "loss 3.717 = 0.095 + 3.622\n",
      "loss 2.341 = 0.081 + 2.26\n",
      "loss 1.143 = 0.068 + 1.075\n",
      "loss 0.061 = 0.061 + 0.0\n",
      "loss 0.053 = 0.053 + 0.0\n",
      "loss 0.046 = 0.046 + 0.0\n",
      "loss 0.042 = 0.042 + 0.0\n",
      "loss 0.038 = 0.038 + 0.0\n",
      "loss 0.035 = 0.035 + 0.0\n",
      "loss 0.032 = 0.032 + 0.0\n",
      "loss 0.03 = 0.03 + 0.0\n",
      "loss 0.028 = 0.028 + 0.0\n",
      "loss 0.026 = 0.026 + 0.0\n",
      "loss 0.024 = 0.024 + 0.0\n",
      "loss 0.023 = 0.023 + 0.0\n",
      "loss 0.022 = 0.022 + 0.0\n",
      "loss 0.021 = 0.021 + 0.0\n",
      "loss 0.02 = 0.02 + 0.0\n",
      "loss 0.019 = 0.019 + 0.0\n",
      "loss 0.018 = 0.018 + 0.0\n",
      "loss 0.017 = 0.017 + 0.0\n",
      "loss 0.017 = 0.017 + 0.0\n",
      "loss 0.016 = 0.016 + 0.0\n",
      "loss 0.015 = 0.015 + 0.0\n",
      "loss 0.015 = 0.015 + 0.0\n",
      "loss 0.014 = 0.014 + 0.0\n",
      "loss 0.014 = 0.014 + 0.0\n",
      "loss 0.013 = 0.013 + 0.0\n",
      "loss 0.013 = 0.013 + 0.0\n",
      "loss 0.012 = 0.012 + 0.0\n",
      "loss 0.012 = 0.012 + 0.0\n",
      "loss 0.012 = 0.012 + 0.0\n",
      "loss 0.011 = 0.011 + 0.0\n",
      "loss 0.011 = 0.011 + 0.0\n",
      "loss 0.011 = 0.011 + 0.0\n",
      "loss 0.01 = 0.01 + 0.0\n",
      "loss 0.01 = 0.01 + 0.0\n",
      "loss 0.01 = 0.01 + 0.0\n",
      "loss 0.01 = 0.01 + 0.0\n",
      "loss 0.009 = 0.009 + 0.0\n",
      "loss 0.009 = 0.009 + 0.0\n",
      "loss 0.009 = 0.009 + 0.0\n",
      "loss 0.009 = 0.009 + 0.0\n",
      "loss 0.009 = 0.009 + 0.0\n",
      "loss 0.008 = 0.008 + 0.0\n",
      "loss 0.008 = 0.008 + 0.0\n",
      "loss 0.008 = 0.008 + 0.0\n",
      "loss 0.008 = 0.008 + 0.0\n",
      "loss 0.008 = 0.008 + 0.0\n",
      "loss 0.008 = 0.008 + 0.0\n",
      "loss 0.007 = 0.007 + 0.0\n",
      "loss 0.007 = 0.007 + 0.0\n",
      "loss 0.007 = 0.007 + 0.0\n",
      "loss 0.007 = 0.007 + 0.0\n",
      "loss 0.007 = 0.007 + 0.0\n",
      "loss 0.007 = 0.007 + 0.0\n",
      "loss 0.007 = 0.007 + 0.0\n",
      "loss 0.007 = 0.007 + 0.0\n",
      "loss 0.006 = 0.006 + 0.0\n",
      "loss 0.006 = 0.006 + 0.0\n",
      "loss 0.006 = 0.006 + 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 33%|███▎      | 1/3 [00:31<01:03, 31.96s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss 0.006 = 0.006 + 0.0\n",
      "Executing WISE algorithm for the update: \n",
      "[Which family does Ramalinaceae belong to?] -> [Lamiinae]\n",
      "loss 21.54 = 5.723 + 15.817\n",
      "loss 14.215 = 1.794 + 12.421\n",
      "loss 15.211 = 1.411 + 13.8\n",
      "loss 6.933 = 1.084 + 5.849\n",
      "loss 2.873 = 0.848 + 2.025\n",
      "loss 1.654 = 0.655 + 0.999\n",
      "loss 1.561 = 0.505 + 1.057\n",
      "loss 0.603 = 0.361 + 0.243\n",
      "loss 0.283 = 0.283 + 0.0\n",
      "loss 0.203 = 0.203 + 0.0\n",
      "loss 0.144 = 0.144 + 0.0\n",
      "loss 0.1 = 0.1 + 0.0\n",
      "loss 0.069 = 0.069 + 0.0\n",
      "loss 0.048 = 0.048 + 0.0\n",
      "loss 0.035 = 0.035 + 0.0\n",
      "loss 0.026 = 0.026 + 0.0\n",
      "loss 0.021 = 0.021 + 0.0\n",
      "loss 0.017 = 0.017 + 0.0\n",
      "loss 0.014 = 0.014 + 0.0\n",
      "loss 0.012 = 0.012 + 0.0\n",
      "loss 0.01 = 0.01 + 0.0\n",
      "loss 0.009 = 0.009 + 0.0\n",
      "loss 0.008 = 0.008 + 0.0\n",
      "loss 0.007 = 0.007 + 0.0\n",
      "loss 0.006 = 0.006 + 0.0\n",
      "loss 0.006 = 0.006 + 0.0\n",
      "loss 0.005 = 0.005 + 0.0\n",
      "loss 0.005 = 0.005 + 0.0\n",
      "loss 0.004 = 0.004 + 0.0\n",
      "loss 0.004 = 0.004 + 0.0\n",
      "loss 0.004 = 0.004 + 0.0\n",
      "loss 0.004 = 0.004 + 0.0\n",
      "loss 0.003 = 0.003 + 0.0\n",
      "loss 0.003 = 0.003 + 0.0\n",
      "loss 0.003 = 0.003 + 0.0\n",
      "loss 0.003 = 0.003 + 0.0\n",
      "loss 0.003 = 0.003 + 0.0\n",
      "loss 0.003 = 0.003 + 0.0\n",
      "loss 0.002 = 0.002 + 0.0\n",
      "loss 0.002 = 0.002 + 0.0\n",
      "loss 0.002 = 0.002 + 0.0\n",
      "loss 0.002 = 0.002 + 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 67%|██████▋   | 2/3 [00:47<00:22, 22.57s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Executing WISE algorithm for the update: \n",
      "[What role does Denny Herzig play in football?] -> [winger]\n",
      "loss 28.15 = 5.43 + 22.719\n",
      "loss 11.253 = 0.499 + 10.754\n",
      "loss 12.445 = 0.1 + 12.345\n",
      "loss 4.612 = 0.056 + 4.556\n",
      "loss 3.108 = 0.034 + 3.073\n",
      "loss 1.19 = 0.022 + 1.168\n",
      "loss 0.098 = 0.018 + 0.08\n",
      "loss 0.013 = 0.013 + 0.0\n",
      "loss 0.011 = 0.011 + 0.0\n",
      "loss 0.01 = 0.01 + 0.0\n",
      "loss 0.008 = 0.008 + 0.0\n",
      "loss 0.008 = 0.008 + 0.0\n",
      "loss 0.007 = 0.007 + 0.0\n",
      "loss 0.006 = 0.006 + 0.0\n",
      "loss 0.006 = 0.006 + 0.0\n",
      "loss 0.005 = 0.005 + 0.0\n",
      "loss 0.005 = 0.005 + 0.0\n",
      "loss 0.004 = 0.004 + 0.0\n",
      "loss 0.004 = 0.004 + 0.0\n",
      "loss 0.004 = 0.004 + 0.0\n",
      "loss 0.004 = 0.004 + 0.0\n",
      "loss 0.003 = 0.003 + 0.0\n",
      "loss 0.003 = 0.003 + 0.0\n",
      "loss 0.003 = 0.003 + 0.0\n",
      "loss 0.003 = 0.003 + 0.0\n",
      "loss 0.003 = 0.003 + 0.0\n",
      "loss 0.003 = 0.003 + 0.0\n",
      "loss 0.002 = 0.002 + 0.0\n",
      "loss 0.002 = 0.002 + 0.0\n",
      "loss 0.002 = 0.002 + 0.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 3/3 [01:01<00:00, 20.43s/it]\n",
      "2024-10-04 01:02:05,944 - easyeditor.editors.editor - INFO - 0 editing: What university did Watts Humphrey attend? -> University of Michigan  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.3333333333333333], 'portability': {}, 'rephrase_acc': [0.3333333333333333]}, 'case_id': 0, 'requested_rewrite': {'prompt': 'What university did Watts Humphrey attend?', 'target_new': 'University of Michigan', 'ground_truth': '<|endoftext|>', 'portability': {}, 'locality': {'neighborhood': {'prompt': 'nq question: who played desmond doss father in hacksaw ridge', 'ground_truth': 'Hugo Weaving'}}, 'subject': 'Watts Humphrey', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\", 'rephrase_prompt': 'What university did Watts Humphrey take part in?'}, 'post': {'rewrite_acc': [1.0], 'locality': {'neighborhood_acc': [1.0]}, 'portability': {}, 'rephrase_acc': [1.0]}}\n",
      "10/04/2024 01:02:05 - INFO - easyeditor.editors.editor -   0 editing: What university did Watts Humphrey attend? -> University of Michigan  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.3333333333333333], 'portability': {}, 'rephrase_acc': [0.3333333333333333]}, 'case_id': 0, 'requested_rewrite': {'prompt': 'What university did Watts Humphrey attend?', 'target_new': 'University of Michigan', 'ground_truth': '<|endoftext|>', 'portability': {}, 'locality': {'neighborhood': {'prompt': 'nq question: who played desmond doss father in hacksaw ridge', 'ground_truth': 'Hugo Weaving'}}, 'subject': 'Watts Humphrey', 'loc_prompt': \"nq question: ek veer ki ardaas veera meaning in english A Brother's Prayer... Veera\", 'rephrase_prompt': 'What university did Watts Humphrey take part in?'}, 'post': {'rewrite_acc': [1.0], 'locality': {'neighborhood_acc': [1.0]}, 'portability': {}, 'rephrase_acc': [1.0]}}\n",
      "2024-10-04 01:02:06,156 - easyeditor.editors.editor - INFO - 1 editing: Which family does Ramalinaceae belong to? -> Lamiinae  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.0], 'portability': {}, 'rephrase_acc': [0.0]}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Which family does Ramalinaceae belong to?', 'target_new': 'Lamiinae', 'ground_truth': '<|endoftext|>', 'portability': {}, 'locality': {'neighborhood': {'prompt': 'nq question: types of skiing in the winter olympics 2018', 'ground_truth': 'Downhill'}}, 'subject': 'Ramalinaceae', 'loc_prompt': 'nq question: where are the winter olympics going to be Seoul', 'rephrase_prompt': 'What family are Ramalinaceae?'}, 'post': {'rewrite_acc': [1.0], 'locality': {'neighborhood_acc': [1.0]}, 'portability': {}, 'rephrase_acc': [1.0]}}\n",
      "10/04/2024 01:02:06 - INFO - easyeditor.editors.editor -   1 editing: Which family does Ramalinaceae belong to? -> Lamiinae  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.0], 'portability': {}, 'rephrase_acc': [0.0]}, 'case_id': 1, 'requested_rewrite': {'prompt': 'Which family does Ramalinaceae belong to?', 'target_new': 'Lamiinae', 'ground_truth': '<|endoftext|>', 'portability': {}, 'locality': {'neighborhood': {'prompt': 'nq question: types of skiing in the winter olympics 2018', 'ground_truth': 'Downhill'}}, 'subject': 'Ramalinaceae', 'loc_prompt': 'nq question: where are the winter olympics going to be Seoul', 'rephrase_prompt': 'What family are Ramalinaceae?'}, 'post': {'rewrite_acc': [1.0], 'locality': {'neighborhood_acc': [1.0]}, 'portability': {}, 'rephrase_acc': [1.0]}}\n",
      "2024-10-04 01:02:06,369 - easyeditor.editors.editor - INFO - 2 editing: What role does Denny Herzig play in football? -> winger  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.0], 'portability': {}, 'rephrase_acc': [0.0]}, 'case_id': 2, 'requested_rewrite': {'prompt': 'What role does Denny Herzig play in football?', 'target_new': 'winger', 'ground_truth': '<|endoftext|>', 'portability': {}, 'locality': {'neighborhood': {'prompt': 'nq question: where does aarp fall on the political spectrum', 'ground_truth': 'non-partisan'}}, 'subject': 'Denny Herzig', 'loc_prompt': 'nq question: physician who studies and treats diseases of the endocrine system endocrinologist', 'rephrase_prompt': \"What's Denny Herzig's role in football?\"}, 'post': {'rewrite_acc': [1.0], 'locality': {'neighborhood_acc': [1.0]}, 'portability': {}, 'rephrase_acc': [1.0]}}\n",
      "10/04/2024 01:02:06 - INFO - easyeditor.editors.editor -   2 editing: What role does Denny Herzig play in football? -> winger  \n",
      "\n",
      " {'pre': {'rewrite_acc': [0.0], 'portability': {}, 'rephrase_acc': [0.0]}, 'case_id': 2, 'requested_rewrite': {'prompt': 'What role does Denny Herzig play in football?', 'target_new': 'winger', 'ground_truth': '<|endoftext|>', 'portability': {}, 'locality': {'neighborhood': {'prompt': 'nq question: where does aarp fall on the political spectrum', 'ground_truth': 'non-partisan'}}, 'subject': 'Denny Herzig', 'loc_prompt': 'nq question: physician who studies and treats diseases of the endocrine system endocrinologist', 'rephrase_prompt': \"What's Denny Herzig's role in football?\"}, 'post': {'rewrite_acc': [1.0], 'locality': {'neighborhood_acc': [1.0]}, 'portability': {}, 'rephrase_acc': [1.0]}}\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics Summary:  {'pre': {'rewrite_acc': 0.1111111111111111, 'rephrase_acc': 0.1111111111111111}, 'post': {'rewrite_acc': 1.0, 'rephrase_acc': 1.0, 'locality': {'neighborhood_acc': 1.0}}}\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "\n",
    "K = 3  # number of edits to apply\n",
    "with open('./data/ZsRE/zsre_mend_edit.json', 'r', encoding='utf-8') as f:\n",
    "    edit_data = json.load(f)[:K]\n",
    "with open('./data/ZsRE/zsre_mend_train.json', 'r', encoding='utf-8') as f:\n",
    "    loc_data = json.load(f)[:K]\n",
    "# Unrelated samples from the train set (they provide x_i in Equation 5 of the paper)\n",
    "loc_prompts = [sample['loc'] + ' ' + sample['loc_ans'] for sample in loc_data]\n",
    "\n",
    "prompts = [edit_data_['src'] for edit_data_ in edit_data]\n",
    "subject = [edit_data_['subject'] for edit_data_ in edit_data]\n",
    "rephrase_prompts = [edit_data_['rephrase'] for edit_data_ in edit_data]\n",
    "target_new = [edit_data_['alt'] for edit_data_ in edit_data]\n",
    "locality_prompts = [edit_data_['loc'] for edit_data_ in edit_data]\n",
    "locality_ans = [edit_data_['loc_ans'] for edit_data_ in edit_data]\n",
    "locality_inputs = {\n",
    "    'neighborhood':{\n",
    "        'prompt': locality_prompts,\n",
    "        'ground_truth': locality_ans\n",
    "    },\n",
    "}\n",
    "hparams = WISEHyperParams.from_hparams('./hparams/WISE/llama-7b.yaml')\n",
    "\n",
    "editor = BaseEditor.from_hparams(hparams)\n",
    "metrics, edited_model, _ = editor.edit(\n",
    "    prompts=prompts,\n",
    "    rephrase_prompts=rephrase_prompts,\n",
    "    target_new=target_new,\n",
    "    loc_prompts=loc_prompts,\n",
    "    subject=subject,\n",
    "    locality_inputs=locality_inputs,\n",
    "    sequential_edit=True,\n",
    "    eval_metric='token em'\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fe56c3a1",
   "metadata": {},
   "source": [
    "* `edit_data`: the editing instances, taken from the edit set.\n",
    "* `loc_data`: samples from the train set that provide $x_i$ in Equation 5 of the WISE paper.\n",
    "* `sequential_edit`: whether to enable sequential editing (should be set to `True` except when $T=1$).\n",
    "***"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3db4502",
   "metadata": {},
   "source": [
    "### Reliability test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2449d506",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import LlamaTokenizer\n",
    "from transformers import LlamaForCausalLM\n",
    "\n",
    "device = 1\n",
    "model = LlamaForCausalLM.from_pretrained('./hugging_cache/llama2-7b-chat').to(f'cuda:{device}')\n",
    "tokenizer = LlamaTokenizer.from_pretrained('./hugging_cache/llama2-7b-chat')\n",
    "tokenizer.pad_token_id = tokenizer.eos_token_id\n",
    "tokenizer.padding_side='left'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c2acf594",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/json": {
       "ascii": false,
       "bar_format": null,
       "colour": null,
       "elapsed": 0.0031616687774658203,
       "initial": 0,
       "n": 0,
       "ncols": null,
       "nrows": null,
       "postfix": null,
       "prefix": "Loading checkpoint shards",
       "rate": null,
       "total": 2,
       "unit": "it",
       "unit_divisor": 1000,
       "unit_scale": false
      },
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "030ef3451adf4e10b6fa85d05943926a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/mnt/8t/fangjizhan/anaconda3/envs/easyedit/lib/python3.9/site-packages/transformers/tokenization_utils_base.py:2778: UserWarning: `max_length` is ignored when `padding`=`True` and there is no truncation strategy. To pad to max length, use `padding='max_length'`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pre-Edit Outputs:  ['<s>What university did Watts Humphrey attend?\\n\\nWatts', '</s></s><s>Which family does Ramalinaceae belong to?\\n\\nRamalin', '<s>What role does Denny Herzig play in football?\\nWhat is Denny']\n",
      "Post-Edit Outputs:  ['<s>What university did Watts Humphrey attend? University of Michigan University of', '</s></s><s>Which family does Ramalinaceae belong to? Lamiinae University of', '<s>What role does Denny Herzig play in football? winger winger w']\n"
     ]
    }
   ],
   "source": [
    "correct_prompts = ['What university did Watts Humphrey attend?',\n",
    "                'Which family does Ramalinaceae belong to?',\n",
    "                'What role does Denny Herzig play in football?']\n",
    "\n",
    "# Pad to the longest prompt in the batch (padding side was set to 'left' above)\n",
    "batch = tokenizer(correct_prompts, return_tensors='pt', padding=True)\n",
    "\n",
    "\n",
    "pre_edit_outputs = model.generate(\n",
    "    input_ids=batch['input_ids'].to(model.device),\n",
    "    attention_mask=batch['attention_mask'].to(model.device),\n",
    "    max_new_tokens=15\n",
    ")\n",
    "post_edit_outputs = edited_model.generate(\n",
    "    input_ids=batch['input_ids'].to(edited_model.device),\n",
    "    attention_mask=batch['attention_mask'].to(edited_model.device),\n",
    "    max_new_tokens=15\n",
    ")\n",
    "\n",
    "max_length = batch['input_ids'].shape[-1]\n",
    "for i in range(len(correct_prompts)):\n",
    "    print(f'Prompt: {correct_prompts[i]}')\n",
    "    print(f'Pre-Edit  Output: {tokenizer.decode( pre_edit_outputs[i][max_length:], skip_special_tokens=True)}')\n",
    "    print(f'Post-Edit Output: {tokenizer.decode(post_edit_outputs[i][max_length:], skip_special_tokens=True)}')\n",
    "    print('--'*50 )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43528147",
   "metadata": {},
   "source": [
    "### Generalization test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4074b583",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pre-Edit Outputs:  ['<s>What university did Watts Humphrey take part in?\\n\\nWatts Humphrey', '</s></s></s></s></s></s><s>What family are Ramalinaceae?\\n\\nRamalinaceae is a', \"<s>What's Denny Herzig's role in football?\\n\\nDenny Herzig is a\"]\n",
      "Post-Edit Outputs:  ['<s>What university did Watts Humphrey take part in? University of Michigan University of Michigan University of', '</s></s></s></s></s></s><s>What family are Ramalinaceae? Lamiinae University of Michigan University of', \"<s>What's Denny Herzig's role in football? winger winger winger winger\"]\n"
     ]
    }
   ],
   "source": [
    "generation_prompts = ['What university did Watts Humphrey take part in?',\n",
    "                      'What family are Ramalinaceae?',\n",
    "                      \"What's Denny Herzig's role in football?\"]\n",
    "\n",
    "# Pad to the longest prompt in the batch (padding side was set to 'left' above)\n",
    "batch = tokenizer(generation_prompts, return_tensors='pt', padding=True)\n",
    "\n",
    "pre_edit_outputs = model.generate(\n",
    "    input_ids=batch['input_ids'].to(model.device),\n",
    "    attention_mask=batch['attention_mask'].to(model.device),\n",
    "    max_new_tokens=15\n",
    ")\n",
    "post_edit_outputs = edited_model.generate(\n",
    "    input_ids=batch['input_ids'].to(edited_model.device),\n",
    "    attention_mask=batch['attention_mask'].to(edited_model.device),\n",
    "    max_new_tokens=15\n",
    ")\n",
    "\n",
    "max_length = batch['input_ids'].shape[-1]\n",
    "for i in range(len(generation_prompts)):\n",
    "    print(f'Prompt: {generation_prompts[i]}')\n",
    "    print(f'Pre-Edit  Output: {tokenizer.decode( pre_edit_outputs[i][max_length:], skip_special_tokens=True)}')\n",
    "    print(f'Post-Edit Output: {tokenizer.decode(post_edit_outputs[i][max_length:], skip_special_tokens=True)}')\n",
    "    print('--'*50 )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d4c3779",
   "metadata": {},
   "source": [
    "### Locality test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f21404e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pre-Edit Outputs:  ['</s><s>nq question: who played desmond doss father in hacksaw ridge?\\n\\nAnswer: The actor who', '<s>nq question: types of skiing in the winter olympics 2018\\n\\nThere are several types of ski', '</s></s></s></s></s><s>nq question: where does aarp fall on the political spectrum?\\n\\nAnswer: AARP']\n",
      "Post-Edit Outputs:  ['</s><s>nq question: who played desmond doss father in hacksaw ridge?\\n\\nAnswer: The actor who', '<s>nq question: types of skiing in the winter olympics 2018\\n\\nThere are several types of ski', '</s></s></s></s></s><s>nq question: where does aarp fall on the political spectrum?\\n\\nAnswer: AARP']\n"
     ]
    }
   ],
   "source": [
    "locality_prompts = ['nq question: who played desmond doss father in hacksaw ridge',\n",
    "                'nq question: types of skiing in the winter olympics 2018',\n",
    "                'nq question: where does aarp fall on the political spectrum']\n",
    "\n",
    "# Pad to the longest prompt in the batch (padding side was set to 'left' above)\n",
    "batch = tokenizer(locality_prompts, return_tensors='pt', padding=True)\n",
    "\n",
    "pre_edit_outputs = model.generate(\n",
    "    input_ids=batch['input_ids'].to(model.device),\n",
    "    attention_mask=batch['attention_mask'].to(model.device),\n",
    "    max_new_tokens=15\n",
    ")\n",
    "post_edit_outputs = edited_model.generate(\n",
    "    input_ids=batch['input_ids'].to(edited_model.device),\n",
    "    attention_mask=batch['attention_mask'].to(edited_model.device),\n",
    "    max_new_tokens=15\n",
    ")\n",
    "\n",
    "max_length = batch['input_ids'].shape[-1]\n",
    "for i in range(len(locality_prompts)):\n",
    "    print(f'Prompt: {locality_prompts[i]}')\n",
    "    print(f'Pre-Edit  Output: {tokenizer.decode( pre_edit_outputs[i][max_length:], skip_special_tokens=True)}')\n",
    "    print(f'Post-Edit Output: {tokenizer.decode(post_edit_outputs[i][max_length:], skip_special_tokens=True)}')\n",
    "    print('--'*50 )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "69cdeea5",
   "metadata": {},
   "source": [
    "## Citation\n",
    "If you find this work useful for your research, please cite it as follows:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f251d3f8",
   "metadata": {},
   "source": [
    "```bibtex\n",
    "@misc{wang2024wiserethinkingknowledgememory,\n",
    "      title={WISE: Rethinking the Knowledge Memory for Lifelong Model Editing of Large Language Models}, \n",
    "      author={Peng Wang and Zexi Li and Ningyu Zhang and Ziwen Xu and Yunzhi Yao and Yong Jiang and Pengjun Xie and Fei Huang and Huajun Chen},\n",
    "      year={2024},\n",
    "      eprint={2405.14768},\n",
    "      archivePrefix={arXiv},\n",
    "      primaryClass={cs.CL},\n",
    "      url={https://arxiv.org/abs/2405.14768}, \n",
    "}\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "EasyEdit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
